import math

import numpy as np
import networkx as nx
from sklearn.preprocessing import normalize

import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import to_networkx, coalesce, to_undirected
from torch_geometric.utils.convert import from_networkx
from torch_scatter import scatter_add

from seeds import development_seed

# Relative directory for dataset files; presumably consumed by loaders elsewhere in the project.
DATA_PATH = "data"

def get_feature_mask(rates, group_mask, n_nodes, n_features):
    """Sample a boolean [n_nodes, n_features] mask; True means the feature is present.

    Every entry is an independent Bernoulli draw. Rows of nodes outside the group
    are present with probability ``1 - rates[0]``; rows where ``group_mask`` is True
    are present with probability ``1 - rates[1]``.
    """
    missing_rate_default, missing_rate_group = rates[0], rates[1]
    keep_prob = torch.empty(n_nodes, n_features).fill_(1 - missing_rate_default)
    keep_prob[group_mask] = 1 - missing_rate_group
    sampled = torch.bernoulli(keep_prob)
    return sampled.bool()

# Degree-based importance sampling for missing features.
def get_degree_based_feature_mask(edge_index, edge_weight, n_nodes, n_features):
    """Sample a boolean [n_nodes, n_features] mask; True means the feature is present.

    Node ``i`` keeps each of its features independently with probability
    ``deg_i / sum_j deg_j``, where ``deg`` is the weighted in-degree (sum of
    ``edge_weight`` over edges whose ``edge_index[1]`` is the node).

    NOTE(review): these probabilities sum to 1 over all nodes, so on average only
    one node per feature column is marked present — confirm this is intended.
    """
    in_degree = scatter_add(edge_weight, edge_index[1], dim=0, dim_size=n_nodes)
    node_keep_prob = in_degree / in_degree.sum()
    # Share each node's probability across all of its features.
    keep_prob = node_keep_prob.reshape(-1, 1).repeat(1, n_features)
    return torch.bernoulli(keep_prob).bool()


def get_group_mask(rate, n_nodes):
    """Sample a boolean [n_nodes] mask; each node is in the group (True) with probability ``rate``."""
    membership_prob = torch.empty(n_nodes).fill_(rate)
    in_group = torch.bernoulli(membership_prob)
    return in_group.bool()


def set_train_val_test_split(
        seed: int,
        data: Data,
        dataset_name: str,
        split_idx: dict = None) -> Data:
    """Attach ``train_mask``/``val_mask``/``test_mask`` to ``data`` using the protocol of ``dataset_name``.

    Args:
        seed: Seed for the random parts of the split.
        data: Graph whose ``y`` attribute holds node labels; its mask attributes are
            set in place and the same object is returned.
        dataset_name: Dataset identifier that selects the splitting protocol.
        split_idx: Optional predefined split with keys ``'train'``, ``'valid'`` and
            ``'test'``. It is required for the OGBN datasets. For the per-class
            datasets it replaces the random split.

    Returns:
        ``data`` with the three mask attributes set.

    Raises:
        ValueError: if ``dataset_name`` has no known splitting protocol.
    """
    if dataset_name in ["Cora", "CiteSeer", "PubMed", "Photo", "Computers", "CoauthorCS", "CoauthorPhysics"]:
        # Use split from "Diffusion Improves Graph Learning" paper, which selects 20 nodes for each class to be in the training set
        num_val = 5000 if dataset_name == "CoauthorCS" else 1500
        data = set_per_class_train_val_test_split(seed=seed, data=data, num_val=num_val, num_train_per_class=20, split_idx=split_idx)
    elif dataset_name in ["OGBN-Arxiv", "OGBN-Products"]:
        # OGBN datasets have pre-assigned split
        data.train_mask = split_idx['train']
        data.val_mask = split_idx['valid']
        data.test_mask = split_idx['test']
    elif dataset_name in ['credit', 'german', 'bail', "Twitch", "Deezer-Europe", "FB100", "Actor"] or dataset_name.startswith('sbm'):
        # Datasets from "New Benchmarks for Learning on Non-Homophilous Graphs". They use uniform 50/25/25 split
        data = set_uniform_train_val_test_split(seed, data, train_ratio=0.5, val_ratio=0.25)
    elif dataset_name == 'Syn-Cora':
        # Datasets from "Beyond Homophily in Graph Neural Networks: Current Limitations and Effective Designs". They use uniform 25/25/50 split
        data = set_uniform_train_val_test_split(seed, data, train_ratio=0.25, val_ratio=0.25)
    elif dataset_name == 'MixHopSynthetic':
        # Datasets from "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing". They use uniform 33/33/33 split
        data = set_uniform_train_val_test_split(seed, data, train_ratio=0.33, val_ratio=0.33)
    else:
        raise ValueError(f"We don't know how to split the data for {dataset_name}")

    return data

def set_per_class_train_val_test_split(
        seed: int,
        data: Data,
        num_val: int = 1500,
        num_train_per_class: int = 20,
        split_idx: dict = None) -> Data:
    """Attach per-class train/val/test masks to ``data``.

    Without ``split_idx``, the split is built as follows:
      * A development pool of ``num_val`` distinct nodes is drawn with the fixed
        ``development_seed``, so the pool does not depend on ``seed``.
      * Every node outside that pool becomes a test node.
      * For every label ``c`` in ``0 .. data.y.max()``, ``num_train_per_class`` pool
        nodes with label ``c`` are drawn (with ``seed``) as training nodes.
      * The remaining pool nodes become validation nodes.

    With ``split_idx`` (keys ``'train'``, ``'valid'``, ``'test'``), its entries are
    assigned to the masks unchanged.

    Args:
        seed: Seed for choosing the training nodes inside the development pool.
        data: Graph with node labels in ``data.y``; its mask attributes are set in
            place and the same object is returned.
        num_val: Size of the development (train + validation) pool.
        num_train_per_class: Number of training nodes drawn per class.
        split_idx: Optional predefined split that replaces the random one.

    Returns:
        ``data`` with ``train_mask``, ``val_mask`` and ``test_mask`` set.

    Raises:
        ValueError: (from numpy) if ``num_val`` exceeds the number of nodes, or if a
            class has fewer than ``num_train_per_class`` nodes in the development pool.
    """
    if split_idx is not None:
        data.train_mask = split_idx['train']
        data.val_mask = split_idx['valid']
        data.test_mask = split_idx['test']
        return data

    num_nodes = data.y.shape[0]

    rnd_state = np.random.RandomState(development_seed)
    development_idx = rnd_state.choice(num_nodes, num_val, replace=False)

    # Complement of the development pool via a boolean mask. This costs O(num_nodes)
    # instead of an O(num_nodes * num_val) per-element membership test.
    in_development = np.zeros(num_nodes, dtype=bool)
    in_development[development_idx] = True
    test_idx = np.flatnonzero(~in_development)

    rnd_state = np.random.RandomState(seed)
    # Loop-invariant: the labels of the development pool, computed once.
    development_labels = data.y[development_idx].cpu()
    train_idx = []
    for c in range(int(data.y.max()) + 1):
        class_idx = development_idx[np.where(development_labels == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, num_train_per_class, replace=False))

    # Training nodes are a subset of the pool, so the set difference is exactly the validation set.
    val_idx = np.setdiff1d(development_idx, train_idx)

    data.train_mask = get_mask(train_idx, num_nodes)
    data.val_mask = get_mask(val_idx, num_nodes)
    data.test_mask = get_mask(test_idx, num_nodes)

    return data

def set_uniform_train_val_test_split(
        seed: int,
        data: Data,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1) -> Data:
    """Randomly split the labelled nodes of ``data`` into train/val/test by the given ratios.

    Nodes labelled -1 are treated as unlabelled and are placed in none of the masks.
    The train and validation sizes are ``floor(ratio * #labelled)``. The test set takes
    every remaining labelled node. The permutation is drawn from
    ``np.random.RandomState(seed)``.

    Side effect: afterwards every -1 label in ``data.y`` is overwritten with 0.

    Returns:
        ``data`` with boolean ``train_mask``, ``val_mask`` and ``test_mask`` of length
        ``data.y.shape[0]``.
    """
    rng = np.random.RandomState(seed)

    labeled_nodes = torch.where(data.y != -1)[0]
    n_labeled = labeled_nodes.shape[0]
    train_end = math.floor(n_labeled * train_ratio)
    val_end = train_end + math.floor(n_labeled * val_ratio)

    order = list(range(n_labeled))
    rng.shuffle(order)

    n_nodes = data.y.shape[0]
    data.train_mask = get_mask(labeled_nodes[order[:train_end]], n_nodes)
    data.val_mask = get_mask(labeled_nodes[order[train_end:val_end]], n_nodes)
    data.test_mask = get_mask(labeled_nodes[order[val_end:]], n_nodes)

    # Give unlabelled nodes the placeholder label 0. Label propagation one-hot encodes
    # all labels and cannot handle -1. These nodes are in no mask, so the placeholder
    # never affects any result.
    data.y[data.y == -1] = 0

    return data

def get_mask(idx, num_nodes):
    """Return a boolean tensor of length ``num_nodes`` that is True exactly at positions ``idx``."""
    selected = torch.zeros(num_nodes, dtype=torch.bool)
    selected[idx] = True
    return selected

def get_random_walk(edge_index, edge_weight, n_nodes, n_features):
    """Degree-normalise edge weights and replicate the operator once per feature.

    Each weight is divided by the weighted degree of its target node, i.e. the sum of
    ``edge_weight`` over edges sharing the same ``edge_index[1]``.

    Returns:
        ``(edge_index_list, edge_weight_list)``, two lists of length ``n_features``.
        The lists repeat references to the same ``edge_index`` and the same normalised
        weight tensor; they are not copies.
    """
    target = edge_index[1]
    target_degree = scatter_add(edge_weight, target, dim=0, dim_size=n_nodes)
    normalized_weight = target_degree.pow(-1)[target] * edge_weight
    return [edge_index] * n_features, [normalized_weight] * n_features

def get_global_mean(feature_mask, num_nodes):
    """Build one propagation graph per feature that maps missing entries to the observed mean.

    For feature ``f``, let K be the nodes with ``feature_mask[:, f]`` True and U the
    rest. The graph has a self-loop on every node in K and an edge from every node in
    K to every node in U. Unit weights are normalised with ``get_random_walk`` (divided
    by the target's in-degree). Each U node therefore receives weight ``1/|K|`` from
    every observed node, and each K node keeps its self-loop with weight 1.
    Presumably consumers aggregate from ``edge_index[0]`` to ``edge_index[1]`` — confirm.

    Returns:
        ``(edge_index_list, edge_weight_list)`` with one entry per feature column.
    """
    n_features = feature_mask.size(1)
    edge_index_list, edge_weight_list = [], []
    for feature in range(n_features):
        observed = feature_mask[:, feature]
        known = torch.nonzero(observed).reshape(-1)
        unknown = torch.nonzero(~observed).reshape(-1)
        n_known, n_unknown = known.size(0), unknown.size(0)

        # Self-loops on known nodes first, then a complete bipartite block known -> unknown.
        sources = torch.cat([known, known.repeat_interleave(n_unknown)]).reshape(1, -1)
        targets = torch.cat([known, unknown.repeat(n_known)]).reshape(1, -1)
        assert sources.size(1) == n_known + n_known * n_unknown

        feature_edges = coalesce(torch.cat([sources, targets]))
        unit_weights = torch.ones((feature_edges.size(1), )).to(feature_edges.device)
        _, walk_weights = get_random_walk(feature_edges, unit_weights, int(num_nodes), n_features)

        edge_index_list.append(feature_edges)
        edge_weight_list.append(walk_weights[0])
    return edge_index_list, edge_weight_list


def get_normalized_degree(edge_index, edge_weight, num_nodes):
    """Return ``deg ** -0.5`` per node.

    ``deg`` is the weighted in-degree: the sum of ``edge_weight`` over edges whose
    ``edge_index[1]`` is the node. Nodes with zero degree yield ``inf``.
    """
    target_nodes = edge_index[1]
    weighted_degree = scatter_add(edge_weight, target_nodes, dim=0, dim_size=num_nodes)
    return weighted_degree.pow(-0.5)


def get_contraction_coefficient(edge_index, edge_weight,
                                beta, feature_mask, group_mask, num_nodes):
    """Compute a coefficient from cross-group edge weight and return its absolute value.

    Starting from 1, for each group ``g`` in (1, 0) this subtracts

        (W_missing(g) + beta * W_observed(g)) / |{v : group_mask[v] == g}|

    W_missing(g) is the total ``edge_weight`` of edges whose source (``edge_index[0]``)
    has ``group_mask == 1 - g`` and whose target (``edge_index[1]``) has
    ``group_mask == g`` and ``feature_mask == 0``. W_observed(g) is the same sum for
    targets with ``feature_mask == 1``.

    NOTE(review): ``feature_mask`` is indexed per node, so it is presumably a 1-D mask
    for a single feature — confirm with callers. ``num_nodes`` is accepted but unused.
    An empty group makes its denominator zero (result inf/nan).
    """
    src, dst = edge_index[0], edge_index[1]
    coefficient = 1
    for group in (1, 0):
        from_other_group = group_mask[src] == 1 - group
        into_group = group_mask[dst] == group
        cross_edges = from_other_group & into_group

        weight_to_missing = torch.sum(edge_weight[cross_edges & (feature_mask[dst] == 0)])
        weight_to_observed = torch.sum(edge_weight[cross_edges & (feature_mask[dst] == 1)])

        group_size = (group_mask == group).sum()
        numerator = weight_to_missing.float() + beta * weight_to_observed.float()
        coefficient -= (numerator / group_size).item()

    return abs(coefficient)


def get_max_deviations(x):
    """Return, for every column of ``x``, the largest absolute difference between two of its entries.

    ``max_{i,j} |x[i, f] - x[j, f]|`` equals ``max_i x[i, f] - min_j x[j, f]``. Computing
    it from per-column extrema takes O(n_rows * n_features) time and memory. The previous
    version materialised an n_rows x n_rows difference matrix for every feature. Float
    rounding is monotone, so the result is bit-identical for finite inputs.

    Args:
        x: 2-D tensor; columns (dim 1) are treated independently.

    Returns:
        1-D tensor of length ``x.size(1)`` on ``x.device``, detached from autograd.
    """
    col_max = x.max(dim=0).values
    col_min = x.min(dim=0).values
    return (col_max - col_min).detach()


def get_adj_row_sum(edge_index, edge_weight, n_nodes):
    """
    Return the weighted out-degree of every node, i.e. the row sums of the weighted adjacency matrix.

    Edge ``e`` adds ``edge_weight[e]`` to its source node ``edge_index[0, e]``. The result
    has length ``n_nodes``.
    """
    return scatter_add(edge_weight, edge_index[0], dim=0, dim_size=n_nodes)

def get_adj_col_sum(edge_index, edge_weight, n_nodes):
    """
    Return the weighted in-degree of every node, i.e. the column sums of the weighted adjacency matrix.

    Edge ``e`` adds ``edge_weight[e]`` to its target node ``edge_index[1, e]``. The result
    has length ``n_nodes``.
    """
    return scatter_add(edge_weight, edge_index[1], dim=0, dim_size=n_nodes)

def row_normalize(edge_index, edge_weight, n_nodes):
    """Divide each edge weight by the weighted out-degree of its source node (``edge_index[0]``)."""
    out_degree = get_adj_row_sum(edge_index, edge_weight, n_nodes)
    source_degree = out_degree[edge_index[0]]
    return edge_weight / source_degree

def col_normalize(edge_index, edge_weight, n_nodes):
    """Divide each edge weight by the weighted in-degree of its target node (``edge_index[1]``)."""
    in_degree = get_adj_col_sum(edge_index, edge_weight, n_nodes)
    target_degree = in_degree[edge_index[1]]
    return edge_weight / target_degree

def compute_laplacian_eigenvectors(data, normalized_laplacian=True):
    """Eigen-decompose the Laplacian of the undirected version of ``data``'s graph.

    Args:
        data: Graph accepted by ``to_networkx``.
        normalized_laplacian: Use the symmetric normalized Laplacian if True,
            otherwise the combinatorial Laplacian.

    Returns:
        The result of ``np.linalg.eigh`` on the dense Laplacian: eigenvalues and the
        matching eigenvectors as columns. Row ``i`` of the eigenvector matrix
        corresponds to node ``i``.
    """
    G = to_networkx(data).to_undirected()
    # Pin the node order in BOTH branches. Previously only the normalized branch passed
    # an explicit nodelist, so the other one relied on networkx's implicit ordering.
    nodelist = list(range(G.number_of_nodes()))
    if normalized_laplacian:
        L = nx.normalized_laplacian_matrix(G, nodelist)
    else:
        L = nx.laplacian_matrix(G, nodelist)

    return np.linalg.eigh(L.todense())

def get_random_graph(n_nodes=20, m=2, n_features=10, k=5, seed=None):
    """Generate a Barabási–Albert graph carrying a smooth node signal.

    The signal is a random linear combination of the first ``k`` eigenvectors of the
    graph's normalized Laplacian, with every feature column scaled by sklearn's
    ``normalize(..., axis=0)``.

    Args:
        n_nodes: Number of nodes.
        m: Edges attached from each new node in the Barabási–Albert process.
        n_features: Number of feature columns.
        k: Number of leading eigenvectors mixed into the signal.
        seed: Optional integer. When given, both the graph and the mixing coefficients
            are reproducible. When None (default), the global randomness is used as before.

    Returns:
        The ``from_networkx`` conversion of the graph with ``x`` set to a float tensor
        of shape [n_nodes, n_features].
    """
    G = nx.barabasi_albert_graph(n_nodes, m, seed=seed)
    # .toarray() always yields a plain ndarray. .todense() can return np.matrix, which
    # recent scikit-learn versions reject in `normalize`.
    L = nx.normalized_laplacian_matrix(G).toarray()
    _, eigenvectors = np.linalg.eigh(L)

    rng = np.random if seed is None else np.random.RandomState(seed)
    # Signal is a linear combination of the first k Laplacian eigenvectors (with random coefficients).
    coefficients = rng.uniform(size=(k, n_features))
    x = torch.Tensor(normalize(eigenvectors[:, :k] @ coefficients, axis=0))

    data = from_networkx(G)
    data.x = x

    return data

def compute_neighbor_correlation(edge_index, x):
    """Return the node-averaged dot product between each node's features and its neighbours' mean features.

    The neighbours of node ``i`` are the targets of edges ``(i, j)`` in ``edge_index``.
    Their features are averaged through the row-normalised adjacency. The result is
    currently an uncentred, unscaled dot product, not a true correlation (see TODO).

    Args:
        edge_index: [2, E] index tensor on the same device as ``x``.
        x: Node feature matrix [n_nodes, n_features].

    Returns:
        Python float: mean over nodes of ``<x_i, mean_{j in N(i)} x_j>``.
    """
    n_nodes = x.shape[0]
    # Build the edge weights on x's device and dtype so the sparse product below matches x.
    edge_weight = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device)
    edge_weight = row_normalize(edge_index=edge_index, edge_weight=edge_weight, n_nodes=n_nodes)
    # Pass the size explicitly. If it is inferred from the indices (max index + 1),
    # trailing isolated nodes make the matrix smaller than x and the product fails.
    adj_row_norm = torch.sparse_coo_tensor(edge_index, edge_weight, (n_nodes, n_nodes))

    neighbors_mean_x = torch.sparse.mm(adj_row_norm, x)
    # Compute dot product between each node and the mean of its neighbors
    # (which is equal to the mean of the dot product of a node with its neighbors)
    dot_product = torch.sum(x * neighbors_mean_x, dim=1)
    # TODO: This is actually computing dot-product, we need to change it to compute correlation
    correlation = dot_product

    # Take the mean over all the nodes
    return torch.mean(correlation).item()

    
